
# 将训练的ResNet+LSTM模型 保存的权重文件 分解为单独的ResNet和单独的LSTM 以及 最后的全连接层
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from models.ResNet_LSTM import Res_LSTM


def net2single(weight_path, weight_name, sample_size, sample_duration, num_classes, arch):
    """
        将训练的ResNet+LSTM模型 保存的权重文件 分解为单独的ResNet和单独的LSTM 以及最后的全连接层
        weight_path - 要分解的网络的权重保存目录
        weight_name - 要分解的网络的权重名
        sample_size - 网络的样本大小(宽=高)
        sample_duration - 可以理解为网络输入数据的帧数
        num_classes - 类别数量
    """
    # 创建网络
    Res50_lstm = Res_LSTM(sample_size=sample_size, sample_duration=sample_duration, num_classes=num_classes, arch=arch)
    
    path = os.path.join(weight_path, weight_name)
    print("load weight from: ", path)
    Res50_lstm.load_state_dict(torch.load(path)) # 载入权重

    # 分离ResNet50 并保存 （没有原ResNet50的全连接层）
    path = os.path.join(weight_path, weight_name[:-4]+"-res.pth")
    print("single ResNet and save:  ", path)
    torch.save(Res50_lstm.resnet.state_dict(), path)

    # 分离LSTM 并保存 
    # resnet50的输出 --> view  -->  LSTM层输入 size=(1, 帧数=16, 2048)
    path = os.path.join(weight_path, weight_name[:-4]+"-lstm.pth")
    print("single LSTM and save:  ", path)
    torch.save(Res50_lstm.lstm.state_dict(), path)

    # 分离最后的全连接层 并保存
    path = os.path.join(weight_path, weight_name[:-4]+"-fc.pth")
    print("single fc and save:  ", path)
    torch.save(Res50_lstm.fc1.state_dict(), path)



def test_single(weight_path, weight_name, sample_size, sample_duration, num_classes, arch):
    """
        测试分离后的权重是否正确
    """
    from dataset import CSL_Isolated
    import time

    print("\n\nTest begin...")

    print("load test data...")
    transform = transforms.Compose([transforms.Resize([sample_size, sample_size]), transforms.ToTensor()])
    dataset = CSL_Isolated(data_path="D:\Works\Hisi_codes\SLRDataset_isolate\color_img",
        label_path="D:\Works\Hisi_codes\SLRDataset_isolate\gloss_label_2.txt", frames=sample_duration,
        num_classes=num_classes, transform=transform)
    input = dataset[0]['data'].unsqueeze(0)


    # 构建网络
    print("create net...")
    Res_lstm_net = Res_LSTM(sample_size=sample_size, sample_duration=sample_duration, num_classes=101, arch=arch)

    # 原始未分解的权重 的输出
    print("original network...")
    path = os.path.join(weight_path, weight_name) # 原始未分解的权重
    Res_lstm_net.load_state_dict(torch.load(path))
    Res_lstm_net.eval()
    t1 = time.time()
    output1 = Res_lstm_net(input)
    t2 = time.time()

    # 加载单独的权重
    print("single network...")
    # 单独的ResNet加载权重
    path = os.path.join(weight_path, weight_name[:-4]+"-res.pth")
    Res_lstm_net.resnet.load_state_dict(torch.load(path))
    # 单独的LSTM加载权重
    path = os.path.join(weight_path, weight_name[:-4]+"-lstm.pth")
    Res_lstm_net.lstm.load_state_dict(torch.load(path))
    # 单独的fc加载权重
    path = os.path.join(weight_path, weight_name[:-4]+"-fc.pth")
    Res_lstm_net.fc1.load_state_dict(torch.load(path))

    # 分解后网络的输出
    Res_lstm_net.eval()
    t3 = time.time()
    output2 = Res_lstm_net(input)
    t4 = time.time()

    print("output1==output2: ", output1.equal(output2))

    # print("max's index: ", output.index(max(output)))
    print("max's index(original network): ", torch.argmax(output1))
    print("max's index( single  network): ", torch.argmax(output2))
    print("eval time:  original network", t2-t1, " single network", t4-t3)
    print()

if __name__ == '__main__':
    # Res50+LSTM 256x256 16帧
    weight_path = r"D:\Works\Hisi_codes\SLR-master\res50lstm_models" # 权重所在目录
    weight_name = "res50lstm_epoch026.pth"  # 权重文件名
    sample_size = 256
    sample_duration = 16
    num_classes = 101
    arch="resnet50"
    
    # # 分解
    # net2single(weight_path, weight_name, sample_size, sample_duration, num_classes, arch)

    # 测试分解后的权重是否正确
    test_single(weight_path, weight_name, sample_size, sample_duration, num_classes, arch)

